{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to TrainConfig\n",
    "\n",
    "### Context \n",
    "\n",
    "> Warning: This is still experimental and may change during June / July 2019\n",
    "\n",
    "We introduce here the TrainConfig abstraction, a serializible wrapper to the usual setup used to run federated training: a model, a loss function, an optimizer type and training hyper parameters (batch_size, lr, ...).\n",
    "\n",
    "The main reason why using TrainConfig is to set the limits between a worker (that holds private data and performs training) and another worker that acts as a scheduler (knowns workers, has a model and demands training from this workers).\n",
    "\n",
    "Authors:\n",
    "- Marianne Monteiro - Twitter [@hereismari](https://twitter.com/hereismari) - GitHub: [@mari-linhares](https://github.com/mari-linhares)\n",
    "- Silvia - GitHub [@midokura-silvia](https://github.com/midokura-silvia)\n",
    "- Bobby Wagner - Twitter [@bobbyawagner](https://twitter.com/bobbyawagner) - GitHub: [@robert-wagner](https://github.com/robert-wagner)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Remote Training on a Federate Learning setup\n",
    "\n",
    "For a Federated Learning setup with TrainConfig we consider at least two participants:\n",
    "\n",
    "* A worker that owns a dataset.\n",
    "\n",
    "* An entity that knows the workers and the dataset name that lives in each worker. We'll call this a scheduler.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a worker\n",
    "\n",
    "Let's create a remote worker that holds some data!\n",
    "\n",
    "#### Preparation: Start the websocket worker\n",
    "\n",
    "First, we need to create a remote worker, we'll call it alice. For this, you need to run in a terminal (not possible from the notebook):\n",
    "\n",
    "```bash\n",
    "python start_worker.py --port 8777 --id alice\n",
    "```\n",
    "\n",
    "#### What's going on?\n",
    "\n",
    "Let's have a look at the main function of `start_worker.py`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "import start_worker\n",
    "\n",
    "print(inspect.getsource(start_worker.main))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This script creates a worker and populate it with some toy data using `worker.add_dataset`, the dataset is identified by a key in this case `xor`.\n",
    "\n",
    "The scheduler needs to know the worker (alice) and its dataset (xor) so it can say: \"hey alice, here is a TrainConfig definition could you train using dataset `xor`?\"\n",
    "\n",
    "We can add multiple datasets to a single worker.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting up a scheduler\n",
    "\n",
    "We'll use this notebook as a scheduler, for this we'll need to:\n",
    "\n",
    "* Have a model\n",
    "* Have a loss function\n",
    "* Define an optimizer\n",
    "* Define hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dependencies\n",
    "import torch as th\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "\n",
    "use_cuda = th.cuda.is_available()\n",
    "th.manual_seed(1)\n",
    "device = th.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "\n",
    "import syft as sy\n",
    "from syft import workers\n",
    "\n",
    "hook = sy.TorchHook(th)  # hook torch as always :)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model\n",
    "\n",
    "A model for TrainConfig is a regular torch model with a significant difference: it needs to be serializable. \n",
    "\n",
    "Given that, we can turn a regular torch model into a [jit](https://pytorch.org/docs/stable/jit.html) module. Jit modules use Torchscript.\n",
    "\n",
    "> Torchsript creates serializable and optimizable models from PyTorch code. Any code written in TorchScript can be saved from a Python process and loaded in a process where there is no Python dependency. This facility will allow us to send this model to remote workers. - [jit documentation](https://pytorch.org/docs/stable/jit.html)\n",
    "\n",
    "We can turn a regular module into a jit module using `th.jit.trace`. First we can implement a regular torch model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(th.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(2, 20)\n",
    "        self.fc2 = nn.Linear(20, 10)\n",
    "        self.fc3 = nn.Linear(10, 1)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can trace it using `th.jit.trace` using some mock data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the model\n",
    "model = Net()\n",
    "\n",
    "# The data itself doesn't matter as long as the shape is right\n",
    "mock_data = th.zeros(1, 2)\n",
    "\n",
    "# Create a jit version of the model\n",
    "traced_model = th.jit.trace(model, mock_data)\n",
    "\n",
    "type(traced_model)"
   ]
  },
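  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of what \"serializable\" means here, the sketch below saves the traced model to an in-memory buffer and loads it back with `th.jit.save` and `th.jit.load`. It reuses `traced_model` and `mock_data` from the cell above. The names `buffer` and `restored_model` are only for illustration, and this is not necessarily how PySyft transports the model internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "\n",
    "import torch as th\n",
    "\n",
    "# Serialize the traced model into an in-memory byte buffer\n",
    "buffer = io.BytesIO()\n",
    "th.jit.save(traced_model, buffer)\n",
    "\n",
    "# Rewind the buffer and rebuild the model from the raw bytes\n",
    "buffer.seek(0)\n",
    "restored_model = th.jit.load(buffer)\n",
    "\n",
    "# The restored model should give the same output as the original one\n",
    "print(th.equal(traced_model(mock_data), restored_model(mock_data)))"
   ]
  },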
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loss function\n",
    "\n",
    "The same applies to the loss function, it needs to be serializable. We can define a usual function just changing it to use jit. We can trace the function the same way we need for models or we can use a function decorator called `th.jit.script`.\n",
    "\n",
    "You can read more about jit trace and jit script in the [pytorch jit documentation](https://pytorch.org/docs/stable/jit.html#mixing-tracing-and-scripting)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss function\n",
    "@th.jit.script\n",
    "def loss_fn(target, pred):\n",
    "    return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()\n",
    "\n",
    "type(loss_fn)"
   ]
  },
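  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the loss concrete, here is a small worked example with made-up values (`example_target` and `example_pred` are hypothetical names, not part of the tutorial's data). It reuses `loss_fn` from the cell above. Each of the two squared errors is $0.5^2 = 0.25$, so their mean is also $0.25$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as th\n",
    "\n",
    "# Hypothetical values, chosen so the result is easy to check by hand\n",
    "example_target = th.tensor([[1.0], [0.0]])\n",
    "example_pred = th.tensor([[0.5], [0.5]])\n",
    "\n",
    "# Mean of (1.0 - 0.5)^2 and (0.0 - 0.5)^2\n",
    "print(loss_fn(example_target, example_pred))"
   ]
  },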
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimizer\n",
    "\n",
    "Just say which one you want to use (for now only built in torch optimizers are available)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = \"SGD\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### General hyper parameters and training options\n",
    "\n",
    "TrainConfig currently supports:\n",
    "* batch_size\n",
    "* Optimizer_args: A dict of args used to initialize the optimizer\n",
    "* epochs\n",
    "* max_nr_batches: Maximum number of training steps that will be performed. For large datasets this can be used to run for less than the number of epochs provided.\n",
    "* shuffle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 4\n",
    "optimizer_args = {\"lr\" : 0.1, \"weight_decay\" : 0.01}\n",
    "epochs = 1\n",
    "max_nr_batches = -1  # not used in this example\n",
    "shuffle = True"
   ]
  },
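  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of how these options interact, the sketch below estimates the number of training steps for one call. It is only a reading of the descriptions above: `expected_training_steps` is a hypothetical helper that is not part of PySyft, it assumes a negative `max_nr_batches` means \"no limit\", and PySyft's actual training loop may differ in detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def expected_training_steps(num_samples, batch_size, epochs, max_nr_batches=-1):\n",
    "    \"\"\"Hypothetical helper (not part of PySyft): estimate the number of training steps.\"\"\"\n",
    "    # One step per batch, and the last batch of an epoch may be smaller\n",
    "    steps = epochs * math.ceil(num_samples / batch_size)\n",
    "    # Assumption: a negative max_nr_batches disables the cap\n",
    "    if max_nr_batches >= 0:\n",
    "        steps = min(steps, max_nr_batches)\n",
    "    return steps\n",
    "\n",
    "\n",
    "# The xor toy dataset has 4 samples, so one epoch with batch_size 4 is a single step\n",
    "print(expected_training_steps(num_samples=4, batch_size=4, epochs=1))\n",
    "\n",
    "# A larger, made-up dataset where the cap stops training early\n",
    "print(expected_training_steps(num_samples=1000, batch_size=4, epochs=1, max_nr_batches=100))"
   ]
  },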
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a TrainConfig\n",
    "\n",
    "TrainConfig is just a wrapper to all we defined for the scheduler, creating a train config consists only of sendin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_config = sy.TrainConfig(model=traced_model,\n",
    "                              loss_fn=loss_fn,\n",
    "                              optimizer=optimizer,\n",
    "                              batch_size=batch_size,\n",
    "                              optimizer_args=optimizer_args,\n",
    "                              epochs=epochs,\n",
    "                              shuffle=shuffle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run training remotely\n",
    "\n",
    "Now that we have a TrainConfig instance, we can just send it to a remote worker and the worker will know how it should execute training (which model, loss function, optimizer, ... to use)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Connect to remote worker\n",
    "\n",
    "\n",
    "We'll connect to the worker (alice) that we initiated at the beginning of the tutorial. We'll instantiate a websocket client, our local access point (proxy) to the remote worker.\n",
    "Note that **this step will fail if the worker is not running**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kwargs_websocket = {\"host\": \"localhost\", \"hook\": hook, \"verbose\": False}\n",
    "alice = workers.websocket_client.WebsocketClientWorker(id=\"alice\", port=8777, **kwargs_websocket)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  Send TrainConfig to worker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Send train config\n",
    "train_config.send(alice)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can execute remote training using our TrainConfig instance!\n",
    "\n",
    "### Training remotely with TrainConfig\n",
    "\n",
    "First let's evaluate our model before training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup toy data (xor example)\n",
    "data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)\n",
    "target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)\n",
    "\n",
    "print(\"\\nEvaluation before training\")\n",
    "pred = model(data)\n",
    "loss = loss_fn(target=target, pred=pred)\n",
    "print(\"Loss: {}\".format(loss))\n",
    "print(\"Target: {}\".format(target))\n",
    "print(\"Pred: {}\".format(pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now train the model on alice's data.\n",
    "\n",
    "We know that alice has a dataset identified by \"xor\", so let's ask it to train using this data. Alice knows how to train because we already said it in the TrainConfig."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(10):\n",
    "    loss = alice.fit(dataset_key=\"xor\")  # ask alice to train using \"xor\" dataset\n",
    "    print(\"-\" * 50)\n",
    "    print(\"Iteration %s: alice's loss: %s\" % (epoch, loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_model = train_config.model_ptr.get()\n",
    "\n",
    "print(\"\\nEvaluation after training:\")\n",
    "pred = new_model(data)\n",
    "loss = loss_fn(target=target, pred=pred)\n",
    "print(\"Loss: {}\".format(loss))\n",
    "print(\"Target: {}\".format(target))\n",
    "print(\"Pred: {}\".format(pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Pick our tutorials on GitHub!\n",
    "\n",
    "We made really nice tutorials to get a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.\n",
    "\n",
    "- [Checkout the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)\n",
    "\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! \n",
    "\n",
    "- [Join slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! If you want to start \"one off\" mini-projects, you can go to PySyft GitHub Issues page and search for issues marked `Good First Issue`.\n",
    "\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
